{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55e50b25",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce528bb2",
   "metadata": {},
   "source": [
    "# Unconditional Pixel Diffusion Training\n",
    "\n",
    "In this notebook, we will train a simple `PixelDiffusion` model in low resolution (64 by 64).\n",
    "\n",
    "The training should take about 10 hours.\n",
    "\n",
    "---\n",
    "\n",
    "Maps dataset from the pix2pix paper:\n",
    "```bash\n",
    "wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz\n",
    "tar -xvf maps.tar.gz\n",
    "!rm maps.tar.gz\n",
    "```\n",
    "\n",
    "Ideally, you will download this dataset once and store it as `data/maps`. If you're running on colab, it's a good idea to download it once to your personal machine (it's only 240 MB) and then upload it to your colab space when you start a new notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec7d6dc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torchvision.transforms as T\n",
    "from torchvision.transforms import ToTensor\n",
    "from torch.utils.data import Dataset\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import imageio\n",
    "from skimage import io\n",
    "import os\n",
    "\n",
    "from src import *\n",
    "\n",
    "mpl.rcParams['figure.figsize'] = (8, 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "effe3e3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import kornia\n",
    "from kornia.utils import image_to_tensor\n",
    "import kornia.augmentation as KA\n",
    "\n",
    "class SimpleImageDataset(Dataset):\n",
    "    \"\"\"Dataset returning images in a folder.\"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 root_dir,\n",
    "                 transforms=None,\n",
    "                 paired=True,\n",
    "                 return_pair=False):\n",
    "        self.root_dir = root_dir\n",
    "        self.transforms = transforms\n",
    "        self.paired=paired\n",
    "        self.return_pair=return_pair\n",
    "        \n",
    "        # set up transforms\n",
    "        if self.transforms is not None:\n",
    "            if self.paired:\n",
    "                data_keys=2*['input']\n",
    "            else:\n",
    "                data_keys=['input']\n",
    "\n",
    "            self.input_T=KA.container.AugmentationSequential(\n",
    "                *self.transforms,\n",
    "                data_keys=data_keys,\n",
    "                same_on_batch=False\n",
    "            )   \n",
    "        \n",
    "        # check files\n",
    "        supported_formats=['webp','jpg']        \n",
    "        self.files=[el for el in os.listdir(self.root_dir) if el.split('.')[-1] in supported_formats]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.files)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        if torch.is_tensor(idx):\n",
    "            idx = idx.tolist()            \n",
    "\n",
    "        img_name = os.path.join(self.root_dir,\n",
    "                                self.files[idx])\n",
    "        image = image_to_tensor(io.imread(img_name))/255\n",
    "\n",
    "        if self.paired:\n",
    "            c,h,w=image.shape\n",
    "            slice=int(w/2)\n",
    "            image2=image[:,:,slice:]\n",
    "            image=image[:,:,:slice]\n",
    "            if self.transforms is not None:\n",
    "                out = self.input_T(image,image2)\n",
    "                image=out[0][0]\n",
    "                image2=out[1][0]\n",
    "        elif self.transforms is not None:\n",
    "            image = self.input_T(image)[0]\n",
    "\n",
    "        if self.return_pair:\n",
    "            return image2,image\n",
    "        else:\n",
    "            return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "528e3f70",
   "metadata": {},
   "outputs": [],
   "source": [
    "CROP_SIZE=64\n",
    "\n",
    "inp_T=[        \n",
    "        KA.RandomCrop((CROP_SIZE,CROP_SIZE)),\n",
    "    ]\n",
    "\n",
    "train_ds=SimpleImageDataset('./data/maps/train',\n",
    "                            transforms=inp_T\n",
    "                     )\n",
    "\n",
    "test_ds=SimpleImageDataset('./data/maps/val',\n",
    "                           transforms=inp_T\n",
    "                          )\n",
    "\n",
    "for idx in range(16):\n",
    "    plt.subplot(4,4,1+idx)\n",
    "    plt.imshow(train_ds[idx].permute(1,2,0))\n",
    "    plt.axis('off')\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82e372cc",
   "metadata": {},
   "source": [
    "### Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04156f52",
   "metadata": {},
   "outputs": [],
   "source": [
    "model=PixelDiffusion(train_ds,\n",
    "                     lr=1e-4,\n",
    "                     batch_size=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a83cab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(\n",
    "    max_steps=2e5,\n",
    "    callbacks=[EMA(0.9999)],\n",
    "    gpus = [0]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deafb040",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.fit(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4faf7bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "B=8 # number of samples\n",
    "\n",
    "model.cuda()\n",
    "out=model(batch_size=B,shape=(64,64),verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeddad22",
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx in range(out.shape[0]):\n",
    "    plt.subplot(1,len(out),idx+1)\n",
    "    plt.imshow(out[idx].detach().cpu().permute(1,2,0))\n",
    "    plt.axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65df248d",
   "metadata": {},
   "source": [
    "By default, the `DDPM` sampler contained in the model is used, as above.\n",
    "\n",
    "However, you can use a `DDIM` sampler just as well to reduce the number of inference steps:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fac0abca",
   "metadata": {},
   "outputs": [],
   "source": [
    "B=8 # number of samples\n",
    "STEPS=200 # ddim steps\n",
    "\n",
    "ddim_sampler=DDIM_Sampler(STEPS,model.model.num_timesteps)\n",
    "\n",
    "model.cuda()\n",
    "out=model(batch_size=B,sampler=ddim_sampler,shape=(64,64),verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22c834ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx in range(out.shape[0]):\n",
    "    plt.subplot(1,len(out),idx+1)\n",
    "    plt.imshow(out[idx].detach().cpu().permute(1,2,0))\n",
    "    plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91e24258",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
